# Description:
#   Stubs for dynamically loading CUDA.

load(
    "@xla//xla/tsl/platform:rules_cc.bzl",
    "cc_library",
)
load(
    "@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
    "cuda_rpath_flags",
    "if_cuda_is_configured",
    "if_cuda_newer_than",
)
load(
    "//xla/tsl:tsl.default.bzl",
    "if_cuda_libs",
)
load("//xla/tsl/cuda:stub.bzl", "cuda_stub")

# Package-level declaration: targets in this package are declared under the
# "notice" license type.
package(
    licenses = ["notice"],
)

# Invokes the cuda_stub macro (loaded from //xla/tsl/cuda:stub.bzl) on the
# cuBLAS symbol list.
# NOTE(review): presumably this is what produces cublas.inc and cublas.tramp.S,
# which :cublas_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cublas",
    srcs = ["cublas.symbols"],
)

# Stub library for cuBLAS (see the file header: stubs for dynamically loading
# CUDA). srcs, linkopts and deps are all wrapped in if_cuda_is_configured(), so
# they are presumably empty when CUDA is not configured.
cc_library(
    name = "cublas_stub",
    srcs = if_cuda_is_configured([
        "cublas_stub.cc",
        # NOTE(review): presumably generated by cuda_stub(name = "cublas") — confirm.
        "cublas.tramp.S",
    ]),
    # The rpath directory switches on the "13_0" CUDA version check: the shared
    # nvidia/cu13/lib directory when if_cuda_newer_than() holds, otherwise the
    # per-library nvidia/cublas/lib directory.
    linkopts = if_cuda_is_configured(if_cuda_newer_than(
        "13_0",
        if_false = cuda_rpath_flags("nvidia/cublas/lib"),
        if_true = cuda_rpath_flags("nvidia/cu13/lib"),
    )),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cublas.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@com_google_absl//absl/container:flat_hash_set",
        "@local_config_cuda//cuda:cuda_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public cuBLAS target. if_cuda_libs() chooses between the @cuda_cublas
# repository target and the local :cublas_stub.
# NOTE(review): the selection condition is defined in tsl.default.bzl — confirm
# which argument applies in which configuration.
alias(
    name = "cublas",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cublas//:cublas", ":cublas_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the cuBLASLt symbol list.
# NOTE(review): presumably produces cublasLt.inc and cublasLt.tramp.S, which
# :cublas_lt_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cublasLt",
    srcs = ["cublasLt.symbols"],
)

# Stub library for cuBLASLt. srcs and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
# NOTE(review): unlike :cublas_stub, this target sets no rpath linkopts —
# confirm that consumers always get the cuBLAS rpath from another dependency.
cc_library(
    name = "cublas_lt_stub",
    srcs = if_cuda_is_configured([
        "cublasLt_stub.cc",
        "cublasLt.tramp.S",
    ]),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cublasLt.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public cuBLASLt target. if_cuda_libs() chooses between @cuda_cublas//:cublasLt
# and the local :cublas_lt_stub (selection condition defined in
# tsl.default.bzl — TODO confirm).
alias(
    name = "cublas_lt",
    actual = if_cuda_libs("@cuda_cublas//:cublasLt", ":cublas_lt_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the cuda.symbols list.
# NOTE(review): presumably produces cuda.inc and cuda.tramp.S, which the
# cc_library named "cuda" consumes — confirm in stub.bzl.
cuda_stub(
    name = "cuda",
    srcs = ["cuda.symbols"],
)

# Stub library built from cuda_stub.cc and cuda.tramp.S. It is exposed directly
# under the name "cuda": there is no separate "_stub" target or if_cuda_libs()
# alias for it in this file, and it sets no rpath linkopts.
# NOTE(review): presumably this wraps the CUDA driver library, which would
# explain the missing rpath — confirm.
cc_library(
    name = "cuda",  # buildifier: disable=duplicated-name
    srcs = if_cuda_is_configured([
        "cuda_stub.cc",
        "cuda.tramp.S",
    ]),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cuda.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Invokes the cuda_stub macro on the CUDA runtime (cudart) symbol list.
# NOTE(review): presumably produces cudart.inc and cudart.tramp.S, which
# :cudart_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cudart",
    srcs = ["cudart.symbols"],
)

# Stub library for the CUDA runtime (cudart). Unlike the other stubs in this
# file, srcs and deps are chosen with select() on the
# @xla//xla/tsl:is_cuda_enabled_and_oss config setting; in every other
# configuration both are empty.
cc_library(
    name = "cudart_stub",
    srcs = select({
        # Compile the dynamic-loading implementation only when the
        # is_cuda_enabled_and_oss setting matches (CUDA is configured and the
        # library is loaded dynamically).
        "@xla//xla/tsl:is_cuda_enabled_and_oss": [
            "cudart.tramp.S",
            "cudart_stub.cc",
        ],
        "//conditions:default": [],
    }),
    # NOTE(review): linkopts are gated by if_cuda_is_configured(), not by the
    # select() used for srcs/deps — confirm the two conditions are meant to
    # differ. The rpath directory switches on the "13_0" CUDA version check.
    linkopts = if_cuda_is_configured(if_cuda_newer_than(
        "13_0",
        if_false = cuda_rpath_flags("nvidia/cuda_runtime/lib"),
        if_true = cuda_rpath_flags("nvidia/cu13/lib"),
    )),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cudart.inc"],
    visibility = ["//visibility:public"],
    deps = select({
        "@xla//xla/tsl:is_cuda_enabled_and_oss": [
            # The runtime stub additionally depends on the :cuda stub library.
            ":cuda",
            "@com_google_absl//absl/container:flat_hash_set",
            "@local_config_cuda//cuda:cuda_headers",
            "@tsl//tsl/platform:dso_loader",
            "@tsl//tsl/platform:load_library",
            "@tsl//tsl/platform:logging",
        ],
        "//conditions:default": [],
    }),
)

# Public CUDA runtime target. if_cuda_libs() chooses between
# @cuda_cudart//:cudart and the local :cudart_stub (selection condition defined
# in tsl.default.bzl — TODO confirm).
alias(
    name = "cudart",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cudart//:cudart", ":cudart_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the cuDNN symbol list.
# NOTE(review): presumably produces cudnn.inc and cudnn.tramp.S, which
# :cudnn_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cudnn",
    srcs = ["cudnn.symbols"],
)

# Stub library for cuDNN. srcs, linkopts and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
cc_library(
    name = "cudnn_stub",
    srcs = if_cuda_is_configured([
        "cudnn_stub.cc",
        "cudnn.tramp.S",
    ]),
    # Always nvidia/cudnn/lib: no CUDA-version switch here, unlike the stubs
    # that move to nvidia/cu13/lib.
    linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/cudnn/lib")),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cudnn.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@com_google_absl//absl/container:flat_hash_map",
        # cuDNN has its own header target instead of :cuda_headers.
        "@local_config_cuda//cuda:cudnn_header",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public cuDNN target. if_cuda_libs() chooses between @cuda_cudnn//:cudnn and
# the local :cudnn_stub (selection condition defined in tsl.default.bzl —
# TODO confirm).
alias(
    name = "cudnn",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cudnn//:cudnn", ":cudnn_stub"),
    visibility = ["//visibility:public"],
)

# Source-less library that only carries linker options: the rpath flags for
# nvidia/nccl/lib, wrapped in if_cuda_is_configured().
cc_library(
    name = "nccl_rpath_flags",
    linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")),
    visibility = ["//visibility:public"],
)

# Public NCCL rpath target. if_cuda_libs() chooses between @cuda_nccl//:nccl
# and the flags-only :nccl_rpath_flags library (selection condition defined in
# tsl.default.bzl — TODO confirm).
alias(
    name = "nccl_rpath",
    actual = if_cuda_libs("@cuda_nccl//:nccl", ":nccl_rpath_flags"),
    visibility = ["//visibility:public"],
)

# Source-less library that only carries linker options: the rpath flags for the
# "tensorrt" directory, wrapped in if_cuda_is_configured(). No if_cuda_libs()
# alias exists for it in this file.
cc_library(
    name = "tensorrt_rpath",
    linkopts = if_cuda_is_configured(cuda_rpath_flags("tensorrt")),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the cuFFT symbol list.
# NOTE(review): presumably produces cufft.inc and cufft.tramp.S, which
# :cufft_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cufft",
    srcs = ["cufft.symbols"],
)

# Stub library for cuFFT. srcs, linkopts and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
cc_library(
    name = "cufft_stub",
    srcs = if_cuda_is_configured([
        "cufft_stub.cc",
        "cufft.tramp.S",
    ]),
    # The rpath directory switches on the "13_0" CUDA version check: the shared
    # nvidia/cu13/lib directory when if_cuda_newer_than() holds, otherwise the
    # per-library nvidia/cufft/lib directory.
    linkopts = if_cuda_is_configured(if_cuda_newer_than(
        "13_0",
        if_false = cuda_rpath_flags("nvidia/cufft/lib"),
        if_true = cuda_rpath_flags("nvidia/cu13/lib"),
    )),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cufft.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public cuFFT target. if_cuda_libs() chooses between @cuda_cufft//:cufft and
# the local :cufft_stub (selection condition defined in tsl.default.bzl —
# TODO confirm).
alias(
    name = "cufft",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cufft//:cufft", ":cufft_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the CUPTI symbol list.
# NOTE(review): presumably produces cupti.inc and cupti.tramp.S, which
# :cupti_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cupti",
    srcs = ["cupti.symbols"],
)

# Stub library for CUPTI. srcs, data, linkopts and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
cc_library(
    name = "cupti_stub",
    srcs = if_cuda_is_configured([
        "cupti_stub.cc",
        "cupti.tramp.S",
    ]),
    # The only stub in this file with a data dependency: the :cupti_dsos target
    # from @local_config_cuda is made available at runtime.
    data = if_cuda_is_configured(["@local_config_cuda//cuda:cupti_dsos"]),
    # The rpath directory switches on the "13_0" CUDA version check: the shared
    # nvidia/cu13/lib directory when if_cuda_newer_than() holds, otherwise the
    # per-library nvidia/cuda_cupti/lib directory.
    linkopts = if_cuda_is_configured(if_cuda_newer_than(
        "13_0",
        if_false = cuda_rpath_flags("nvidia/cuda_cupti/lib"),
        if_true = cuda_rpath_flags("nvidia/cu13/lib"),
    )),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cupti.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
        "@local_config_cuda//cuda:cupti_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public CUPTI target. if_cuda_libs() chooses between @cuda_cupti//:cupti and
# the local :cupti_stub (selection condition defined in tsl.default.bzl —
# TODO confirm).
alias(
    name = "cupti",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cupti//:cupti", ":cupti_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the cuSOLVER symbol list.
# NOTE(review): presumably produces cusolver.inc and cusolver.tramp.S, which
# :cusolver_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cusolver",
    srcs = ["cusolver.symbols"],
)

# Stub library for cuSOLVER. srcs, linkopts and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
cc_library(
    name = "cusolver_stub",
    srcs = if_cuda_is_configured([
        "cusolver_stub.cc",
        "cusolver.tramp.S",
    ]),
    # The rpath directory switches on the "13_0" CUDA version check: the shared
    # nvidia/cu13/lib directory when if_cuda_newer_than() holds, otherwise the
    # per-library nvidia/cusolver/lib directory.
    linkopts = if_cuda_is_configured(if_cuda_newer_than(
        "13_0",
        if_false = cuda_rpath_flags("nvidia/cusolver/lib"),
        if_true = cuda_rpath_flags("nvidia/cu13/lib"),
    )),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cusolver.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public cuSOLVER target. if_cuda_libs() chooses between
# @cuda_cusolver//:cusolver and the local :cusolver_stub (selection condition
# defined in tsl.default.bzl — TODO confirm).
alias(
    name = "cusolver",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cusolver//:cusolver", ":cusolver_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the cuSPARSE symbol list.
# NOTE(review): presumably produces cusparse.inc and cusparse.tramp.S, which
# :cusparse_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "cusparse",
    srcs = ["cusparse.symbols"],
)

# Stub library for cuSPARSE. srcs, linkopts and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
cc_library(
    name = "cusparse_stub",
    srcs = if_cuda_is_configured([
        "cusparse_stub.cc",
        "cusparse.tramp.S",
    ]),
    # The rpath directory switches on the "13_0" CUDA version check: the shared
    # nvidia/cu13/lib directory when if_cuda_newer_than() holds, otherwise the
    # per-library nvidia/cusparse/lib directory.
    linkopts = if_cuda_is_configured(if_cuda_newer_than(
        "13_0",
        if_false = cuda_rpath_flags("nvidia/cusparse/lib"),
        if_true = cuda_rpath_flags("nvidia/cu13/lib"),
    )),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["cusparse.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@com_google_absl//absl/container:flat_hash_set",
        "@local_config_cuda//cuda:cuda_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public cuSPARSE target. if_cuda_libs() chooses between
# @cuda_cusparse//:cusparse and the local :cusparse_stub (selection condition
# defined in tsl.default.bzl — TODO confirm).
alias(
    name = "cusparse",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_cusparse//:cusparse", ":cusparse_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the NCCL symbol list.
# NOTE(review): presumably produces nccl.inc and nccl.tramp.S, which
# :nccl_stub consumes — confirm in stub.bzl.
cuda_stub(
    name = "nccl",
    srcs = ["nccl.symbols"],
)

# Stub library for NCCL. srcs, linkopts and deps are wrapped in
# if_cuda_is_configured(), so they are presumably empty when CUDA is not
# configured.
# NOTE(review): no other target named "nccl_stub" is visible in this file, so
# the duplicated-name suppression on the name line may be unnecessary — confirm.
cc_library(
    name = "nccl_stub",  # buildifier: disable=duplicated-name
    srcs = if_cuda_is_configured([
        "nccl_stub.cc",
        "nccl.tramp.S",
    ]),
    # Same rpath directory as :nccl_rpath_flags; no CUDA-version switch.
    linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["nccl.inc"],
    visibility = ["//visibility:public"],
    deps = if_cuda_is_configured([
        "@com_google_absl//absl/container:flat_hash_set",
        "@local_config_cuda//cuda:cuda_headers",
        # NCCL headers come from their own repository, @local_config_nccl.
        "@local_config_nccl//:nccl_headers",
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public NCCL target. if_cuda_libs() chooses between @cuda_nccl//:nccl and the
# local :nccl_stub (selection condition defined in tsl.default.bzl —
# TODO confirm).
alias(
    name = "nccl",  # buildifier: disable=duplicated-name
    actual = if_cuda_libs("@cuda_nccl//:nccl", ":nccl_stub"),
    visibility = ["//visibility:public"],
)

# Invokes the cuda_stub macro on the NVSHMEM symbol list.
# NOTE(review): presumably produces nvshmem.inc and nvshmem.tramp.S, which the
# cc_library named "nvshmem" consumes — confirm in stub.bzl.
cuda_stub(
    name = "nvshmem",
    srcs = ["nvshmem.symbols"],
)

# Stub library for NVSHMEM. Naming is inverted relative to the other libraries
# in this file: here the cc_library is "nvshmem" and the if_cuda_libs() alias is
# "nvshmem_stub" (compare :cublas_stub / :cublas). srcs, linkopts and deps are
# wrapped in if_cuda_is_configured().
cc_library(
    name = "nvshmem",  # buildifier: disable=duplicated-name
    srcs = if_cuda_is_configured([
        "nvshmem_stub.cc",
        "nvshmem.tramp.S",
    ]),
    # TODO: standalone pip wheel rpath
    linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nvshmem/lib")),
    # local_defines only affect compilation of this target, not its dependents.
    local_defines = [
        "IMPLIB_EXPORT_SHIMS=1",
    ],
    # Textual header: included by the sources, never compiled on its own.
    textual_hdrs = ["nvshmem.inc"],
    visibility = ["//visibility:public"],
    # Unlike the other stubs, no CUDA header target is listed here.
    deps = if_cuda_is_configured([
        "@tsl//tsl/platform:dso_loader",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:load_library",
    ]),
)

# Public NVSHMEM alias. Despite the "_stub" suffix, this is the if_cuda_libs()
# switch: it chooses between @nvidia_nvshmem//:nvshmem and the local :nvshmem
# cc_library (selection condition defined in tsl.default.bzl — TODO confirm).
alias(
    name = "nvshmem_stub",
    actual = if_cuda_libs("@nvidia_nvshmem//:nvshmem", ":nvshmem"),
    visibility = ["//visibility:public"],
)
